#!/usr/bin/python3
import math
import numpy as np
import matplotlib.pyplot as plt
from numpy.core.fromnumeric import std

def gauss(n, mean, std):
    """Sample a normal probability density at the integer points 0..n-1.

    Args:
        n: number of samples, i.e. length of the returned array.
        mean: centre of the bell curve, in sample units.
        std: standard deviation, in sample units.

    Returns:
        1-D numpy array with
        ``pdf[k] = exp(-(k - mean)**2 / (2*std**2)) / sqrt(2*pi*std**2)``.
    """
    variance = std**2
    grid = np.arange(n)
    exponent = -(grid - mean)**2 / (2 * variance)
    return np.exp(exponent) / math.sqrt(2 * math.pi * variance)

def norm(x):
    """Return a z-score normalised copy of *x* (zero mean, unit standard deviation).

    Args:
        x: numeric array-like. It is always copied, so the caller's data is
            never modified.

    Returns:
        float64 numpy array ``(x - x.mean()) / x.std()`` using the population
        standard deviation (ddof=0). A constant input has zero std and yields
        NaNs (numpy emits a RuntimeWarning).
    """
    # `np.float` was merely an alias of the builtin `float` and was removed in
    # NumPy 1.24. `np.array(..., dtype=float)` always copies (as `astype` did),
    # so the in-place operations below never touch the caller's array.
    x = np.array(x, dtype=float)
    x -= x.mean()
    x /= x.std()
    return x

class LocalProc:
    """Locate the dominant peak of a 1-D profile and measure it.

    ``find_split`` / ``find_split2`` choose a window ``[l, r)`` around
    ``data.argmax()``. ``local_msx`` then computes weighted position
    statistics of ``data[l:r]``.
    """

    def __init__(self, data):
        # np.asarray is a no-op for ndarrays and lets plain sequences work too.
        self.data = np.asarray(data)

    @staticmethod
    def _around_peak(data, peak, n):
        """Return (left, right): up to n samples walking outward from `peak`.

        Both halves start with data[peak]. `left` runs backwards
        (peak, peak-1, ...) and `right` runs forwards (peak, peak+1, ...).
        """
        # data[peak:peak-n:-1] is wrong when peak < n: the negative stop is
        # counted from the END of the array, giving a wrong (usually empty)
        # slice. Reversing first and then truncating never wraps around.
        left = data[peak::-1][:n]
        right = data[peak:peak + n]
        return left, right

    @staticmethod
    def _cut_offset(side_data, spt, side):
        """Offset where the running sum of side_data reaches (1-spt) of its total."""
        cumulative = np.cumsum(side_data)
        threshold = cumulative[-1] * (1 - spt)
        return int(np.searchsorted(cumulative, threshold, side=side))

    def find_split(self, n=32, spt=0.1):
        """Pick a window around the peak, searching each side independently.

        Each side is scanned over at most `n` samples starting at the peak. Its
        cut is the first offset where the running sum reaches (1-spt) of that
        side's total.

        Returns:
            (l, r), also stored as self.l / self.r.
        """
        peak = int(self.data.argmax())
        left, right = self._around_peak(self.data, peak, n)
        l = self._cut_offset(left, spt, 'left')
        r = self._cut_offset(right, spt, 'right')
        # +1: compensates for InvH.afft having the DC bin removed.
        # NOTE(review): local_msx applies these shifted indices to self.data
        # itself, and inputs without a dropped DC bin (e.g. h1**2) get a window
        # shifted one sample to the right -- confirm this is intended.
        self.l = peak - l + 1
        self.r = peak + r + 1
        return self.l, self.r

    def find_split2(self, n=32, spt=0.1):
        """Symmetric cut-point search.

        A single shared threshold is used for both sides: the average of the two
        per-side thresholds ``spt * data[peak] + (1 - spt) * side_total``.
        Not recommended when most of the mass is concentrated in a very small
        region (< 20 points); the estimate becomes wrong there.

        Returns:
            (l, r), also stored as self.l / self.r.
        """
        peak = int(self.data.argmax())
        left, right = self._around_peak(self.data, peak, n)
        cum_l = np.cumsum(left)
        cum_r = np.cumsum(right)
        thr_l = cum_l[0] * spt + cum_l[-1] * (1 - spt)
        thr_r = cum_r[0] * spt + cum_r[-1] * (1 - spt)
        shared = (thr_l + thr_r) / 2
        l = int(np.searchsorted(cum_l, shared, side='left'))
        r = int(np.searchsorted(cum_r, shared, side='right'))
        # +1: same DC-bin compensation (and same caveat) as in find_split.
        self.l = peak - l + 1
        self.r = peak + r + 1
        return self.l, self.r

    def local_msx(self):
        """Weighted position statistics of self.data[self.l:self.r].

        Must be called after find_split or find_split2.

        Returns:
            (mean, std, peak), also stored on self:
            mean -- centroid of the window weighted by the values, plus self.l;
            std  -- weighted standard deviation of position inside the window;
            peak -- centroid weighted by the squared values, plus self.l.

        Raises:
            ZeroDivisionError: if the window is empty or its values sum to zero.
        """
        # float cast: squaring integer samples could otherwise overflow.
        cut = np.asarray(self.data[self.l:self.r], dtype=float)
        total = cut.sum()
        if total == 0:
            raise ZeroDivisionError('local_msx: window [l, r) has zero total weight')
        pos = np.arange(cut.shape[0])
        ex = np.dot(pos, cut) / total
        ex2 = np.dot(pos**2, cut) / total
        # Clamp tiny negative variances caused by floating-point cancellation,
        # which would otherwise make math.sqrt raise a domain error.
        std = math.sqrt(max(ex2 - ex**2, 0.0))
        e2x = np.dot(pos, cut**2) / np.dot(cut, cut)
        self.mean = float(ex + self.l)
        self.std = std
        self.peak = float(e2x + self.l)
        return self.mean, self.std, self.peak


class InvH:
    """Build an inverse filter from the spectrum of a reference signal.

    Typical call order:
    lp_init -> Gasuss_H1 (or Hanning_H1) -> limit_H1h1 -> win_h1
    -> update_invH -> correct_dt -> calcx.
    """

    def __init__(self, ref):
        self.ref = ref
        self.pre_fft()

    def pre_fft(self):
        """Cache the reference spectrum.

        Sets self.rfft = rfft(ref) (complex) and self.afft = |rfft| without the
        DC bin (index 0).
        """
        # Must read self.ref, not a module-level global, so every instance uses
        # the reference it was constructed with.
        self.rfft = np.fft.rfft(self.ref)
        self.afft = np.abs(self.rfft)[1:]

    def lp_init(self, n=25, spt=0.1):
        """Locate the dominant band of self.afft and measure its centre/width.

        Results are kept in self.lp (a LocalProc with l, r, mean, std, peak).
        Suggested parameters:
            Gaussian window (Gasuss_H1): n=25, spt=0.1
            Hanning window (Hanning_H1): n=30, spt=0.02
        """
        self.lp = LocalProc(self.afft)
        self.lp.find_split(n=n, spt=spt)
        self.lp.local_msx()

    def Gasuss_H1(self, mean=None, std=None, std_mul=1.5):
        """Build H1 = gaussian_window / rfft.

        Args:
            mean: window centre in bins; defaults to self.lp.peak (needs lp_init).
            std: window width in bins; defaults to self.lp.std * std_mul.
            std_mul: widening factor applied to the measured std.

        Returns:
            self.H1, a complex array with the same length as self.rfft. Bins
            where rfft is exactly 0 become inf.
        """
        if mean is None:
            # NOTE(review): lp.peak is an index into afft (DC removed), which is
            # one below the matching rfft bin -- confirm the offset is intended.
            mean = self.lp.peak
        if std is None:
            std = self.lp.std * std_mul
        win = gauss(self.rfft.shape[0], mean, std)
        self.H1 = win / self.rfft
        # Rescale so the result is easy to compare with the original waveform;
        # the extra /2 keeps it from covering the original curve.
        self.H1 *= self.afft.max() / win.max() / 2
        return self.H1

    def Hanning_H1(self, l=None, r=None):
        """Build H1 = hanning / rfft over bins [l, r), zero everywhere else.

        l and r default to the window found by lp_init. Returns self.H1.
        """
        if l is None:
            l = self.lp.l
        if r is None:
            r = self.lp.r
        band = self.rfft[l:r]
        inv_band = np.hanning(band.shape[0]) / band
        # zero-pad back to the full rfft length
        self.H1 = np.pad(inv_band, (l, self.rfft.shape[0] - l - inv_band.shape[0]))
        return self.H1

    def limit_H1h1(self, max_abs=None):
        """Optionally clip |H1|, then compute the time-domain filter self.h1.

        Clipping is not recommended because it degrades the resulting waveform,
        so keep max_abs=None. The method must still be called, because it is
        what produces self.h1 = irfft(H1). When max_abs is given, magnitudes are
        clipped to max_abs / len(H1) with the phase kept, and a before/after
        magnitude plot is shown.
        """
        if max_abs is not None:
            ab1 = np.abs(self.H1)
            ang = np.angle(self.H1)
            ab2 = np.clip(ab1, 0, max_abs / self.H1.shape[0])
            self.H1 = ab2 * np.exp(1j * ang)
            plt.plot(ab1)
            plt.plot(ab2)
            plt.show()
        # irfft's default output length is always even. Passing the reference
        # length makes odd-length references round-trip correctly.
        self.h1 = np.fft.irfft(self.H1, n=len(self.ref))

    def win_h1(self, n=50, spt=0.1, s_mul=5):
        """Centre h1 on its energy centroid and trim it with a Hann taper.

        The energy profile h1**2 is analysed with LocalProc.find_split2 and
        local_msx. h1 is then rolled so the centroid sits at the middle, cut to
        +/- (s_mul * std) samples around it, and multiplied by a Hann window.
        """
        # NOTE(review): LocalProc shifts l/r by +1 to compensate for a removed
        # DC bin, which h1**2 does not have -- confirm this is acceptable here.
        lp = LocalProc(self.h1**2)
        lp.find_split2(n=n, spt=spt)
        m, s, _ = lp.local_msx()
        s *= s_mul
        n2 = self.h1.shape[0] // 2
        self.h1 = np.roll(self.h1, n2 - int(m))
        # Clamp so an over-wide window cannot give a negative start index,
        # which would wrap around and return a wrong (usually empty) slice.
        half = min(int(s), n2)
        self.h1 = self.h1[n2 - half:n2 + half]
        self.h1 *= np.hanning(self.h1.shape[0])

    def update_invH(self, n):
        """Zero-pad h1 to length n and store/return its rfft as self.invH."""
        padded = np.pad(self.h1, (0, n - self.h1.shape[0]))
        self.invH = np.fft.rfft(padded)
        return self.invH

    def correct_dt(self, dt=None):
        """Apply a linear phase exp(1j*w*dt) to self.invH and return dt.

        w spans 0..pi evenly over the invH bins. For an even transform length
        this advances the filtered output by dt samples (circularly). If dt is
        None, it is estimated as the squared-amplitude-weighted peak position
        of h1**2 (via LocalProc.find_split + local_msx).
        """
        if dt is None:
            lp = LocalProc(self.h1**2)
            lp.find_split(n=30, spt=0.1)
            _, _, dt = lp.local_msx()
        w = np.linspace(0, math.pi, num=self.invH.shape[0], endpoint=True)
        self.invH *= np.exp(1j * w * dt)
        return dt

    def calcx(self, y):
        """Apply the inverse filter to y, returning irfft(rfft(y) * invH).

        y must have the length that was passed to update_invH, so that the two
        spectra have the same number of bins.
        """
        spectrum = np.fft.rfft(y) * self.invH
        # Passing n keeps odd-length inputs at their original length.
        return np.fft.irfft(spectrum, n=len(y))


# Build the inverse filter from the first 512 samples of save1.npy.
# NOTE(review): this runs at import time (it loads files from the working
# directory); consider moving it under the __main__ guard.
ref = np.load('save1.npy')[:512]
ih = InvH(ref)

# Measure the main band of |rfft(ref)| and build a Gaussian-shaped H1.
ih.lp_init(n=25, spt=0.1)
ih.Gasuss_H1(std_mul=1.8)
# Alternative Hanning-shaped H1: swap the two lines above for these two.
#ih.lp_init(n=30, spt=0.02)
#ih.Hanning_H1()

# h1 = irfft(H1) with no magnitude clipping, then centred and Hann-tapered.
ih.limit_H1h1()
ih.win_h1(n=100)
# NOTE: despite its name, `h1` here is the frequency-domain invH (rfft of the
# tapered h1 zero-padded to 4096 samples), as returned by update_invH.
h1 = ih.update_invH(n=4096)
# Apply a linear-phase correction; dt is estimated from h1**2 by default.
dt = ih.correct_dt()

if __name__ == '__main__':
    # Remove the DC offset of the measured signal, then apply the inverse filter.
    measured = np.load('save.npy')
    measured = measured - measured.mean()
    restored = ih.calcx(measured)

    # Top panel: time-domain filter. Bottom panel: measured vs. filtered signal.
    plt.subplot(211)
    plt.plot(ih.h1)
    plt.subplot(212)
    for trace in (measured, restored):
        plt.plot(trace)
    plt.show()
